--- Code generation for method calls
module frege.compiler.gen.java.MethodCall where

import Data.TreeMap(TreeMap, values)
import Data.List(elemBy)

import Compiler.Utilities as U()

import Compiler.classes.Nice(nice, nicer)

import Compiler.types.AbstractJava
import Compiler.types.Types(unST, Sigma, Tau, TauT, ForAll, RhoTau, RhoFun)
import Compiler.types.Symbols(SymbolT)
import Compiler.types.Global
import Compiler.types.JNames(JName, memberOf)
import Compiler.types.Strictness()

import Compiler.common.Types as CT
import Compiler.common.JavaName

import frege.compiler.tc.Methods (NIKind, niKind)

import Compiler.gen.java.Common
import Compiler.gen.java.Bindings 

--- The return type of a 'Sigma', i.e. the first component of what 'U.returnType' yields for its rho.
returnTau sigma = fst (U.returnType (Sigma.rho sigma))
{--
    Applies @f@ to the 'Tau' of an unquantified 'Sigma'.

    A sigma whose rho is not a context free 'RhoTau' is first converted with 'tauRho'.
    Sigmas with bound type variables yield 'Nothing'.
 -}
sigmaTau f sigma = case sigma of
    ForAll [] (RhoTau [] ty) -> f ty
    ForAll [] rhofun         -> sigmaTau f (ForAll [] (tauRho rhofun))
    _                        -> Nothing

{--
    Tells whether a type is recognized by one of 'U.isUnit', 'U.isMaybe',
    'U.isException' or 'unST' (checked in this order).
 -}
niSpecial g ty = isJust (U.isUnit ty)
              || isJust (U.isMaybe ty)
              || isJust (U.isException g ty)
              || isJust (unST ty)
 

{--
    Tells if a native symbol is wrapped.

    A native 'SymV' is wrapped when it declares thrown exceptions,
    when its return type is special in the sense of 'niSpecial',
    or when 'wildReturn' finds type variables for it.
    Non-native 'SymV's and 'SymD's are never wrapped, any other symbol is an error.
 -}
wrapped g (sym@SymV {nativ = Just _, throwing})
    | not (null throwing)  = true
    | niSpecial g returnTy = true
    | otherwise            = not (null (wildReturn g sym))
    where
        returnTy = fst (U.returnType sym.typ.rho)
wrapped _ SymV {} = false
wrapped _ SymD {} = false
wrapped _ _ = error "wrapped: no symv"
 
{--
    Tell if a native function must be called through its wrapper.

    This is the case when it is wrapped for some other reason
    than that the return type is 'Maybe', namely when it declares
    thrown exceptions, when the return type is special ('niSpecial')
    without being a 'Maybe', or when 'wildReturn' is not empty.

    Applying this to anything but a native 'SymV' is an error.
  -}
wrappedOnly g (sym@SymV {nativ = Just _, throwing})
        = throws || specialButNotMaybe || hasWildReturn
    where
        returnTy           = fst (U.returnType sym.typ.rho)
        throws             = not (null throwing)
        specialButNotMaybe = niSpecial g returnTy && isNothing (U.isMaybe returnTy)
        hasWildReturn      = not (null (wildReturn g sym))
wrappedOnly _ _ = error "wrappedOnly - no native function"
-- 
-- 
{--
    Returns a binding for a direct call of a native method.

    The Java expression is built by the local function @call@ according to the
    'niKind' of the native item. It is bound with the base return type, that is,
    the return type of the symbol with any 'unST', exception and 'Maybe' layers
    removed (see local @baserty@) and the substitution applied.

    Arguments are the global state, the symbol (must be a 'SymV' with a native item),
    a substitution for type variables and the Java expressions for the actual arguments.

    For any other symbol this aborts with 'error'.
 -}
nativeCall ∷ Global → Symbol → TreeMap String Tau → [JExpr] → Binding
nativeCall g (sym@SymV {nativ = Just item, gargs}) subst aexs = newBind g bsig (call jrty args)
    where
        -- return type and argument sigmas of the symbol's type
        (rty, sigmas) = U.returnType sym.typ.rho
        -- only those argument sigmas that 'U.sigmaAsTau' can turn into a Tau
        taus  = [ tau | Just tau <- map U.sigmaAsTau sigmas ]
        -- base return type, after substitution
        brty  = substTau subst (baserty rty)
        bsig  = U.tauAsSigma brty
        -- the generic arguments of the symbol as boxed Java types
        targs = map (boxed . tauJT g . substTau subst) gargs 
        -- a single argument of unit type means the Java call has no arguments
        args | [tau] <- taus, Just _ <- U.isUnit tau = []   -- no arguments
             | otherwise = zipWith (argEx g)  aexs taus
        bjt  = tauJT g brty
        jrty = strict  bjt
        -- retmode = maxStrict jrty         
        
        -- Java expression for one actual argument:
        -- an argument of Maybe type is translated to a JQC (presumably the Java
        -- conditional expression - TODO confirm) that yields null when
        -- constructor(arg) == 0 and arg.asJust().mem1.call() otherwise;
        -- all other arguments are passed unchanged.
        argEx g arg tau
             | Just x <- U.isMaybe tau = JQC checknothing (JAtom "null") evex
             -- Just x <- arrayTau g tau = JCast x bind.jex
             | otherwise = arg
             where
                 -- sbind = primitiveBind  bind
                 checknothing = JBin con "==" (JAtom "0")
                 con = JInvoke (JX.static "constructor" jtRuntime) [arg]
                 evex = JInvoke evm1 []
                 evm1 = JExMem m1ex "call" []
                 m1ex = JExMem just "mem1" []
                 just = JInvoke (JExMem arg "asJust" []) []
        -- strips ST, exception and Maybe layers from the return type, repeatedly
        baserty r
             | Just (_, x) <- unST r       = baserty x
             | Just (_, x) <- U.isException g r = baserty x
             | Just x <- U.isMaybe r       = baserty x
             | Just _ <- U.isUnit r        = r
             | otherwise                   = r
        -- builds the Java expression, depending on the kind of the native item
        call jrty args = case niKind item of
             -- operator: binary or unary, depending on the number of arguments
             NIOp -> case args of
                 [a,b] -> JBin a item b
                 [a]   -> JUnop item a
                 _     -> JAtom "null"           -- error was flagged before
             -- constructor call
             NINew -> JNew jrty args
             -- anonymous class instance (JNewClass) whose methods are made by 'wrapIRMethod'
             NIExtends ->
               let evalStG :: Global -> StG a -> a
                   -- runs a state action against g and keeps the result only
                   evalStG g st = fst $ st.run g
                   x = do g <- getST
                          si <- symInfo sym
                          -- irsym is the symbol found for the type name of the first argument
                          let name  = (head si.argSigs).rho.tau.name
                              irsym = unJust $ g.findit name
                              -- names of the fields of all SymD entries in the environment of irsym
                              nms = mapMaybe (_.name) [ fld | x@SymD{} <- values irsym.env, fld <- x.flds ]
                          -- A field contributes a method only if each lookup below succeeds:
                          -- the symbol of the return type, the field name in its environment,
                          -- the native name of that member, the field name in irsym's environment.
                          return $ flip mapMaybe nms $ \fldnm -> do
                              nativrsym <- g.findit $ si.retSig.rho.tau.name
                              nativsym <- TreeMap.lookup fldnm nativrsym.env
                              nativnm <- nativsym.nativ
                              let nativsi = evalStG g $ symInfo nativsym
                              fldsym <- TreeMap.lookup fldnm irsym.env
                              pure $ wrapIRMethod g (head args) (head si.argJTs) nativsi nativnm fldnm fldsym
               in JNewClass jrty [] (evalStG g x)
             -- cast: the item is invoked like a function with the single argument
             NICast -> case args of
                 [a] -> JInvoke (JAtom item) args    -- was: JCast (Ref (JName "" item) []) a
                 _   -> JAtom "null"
             -- instance method: the first argument is the receiver
             NIMethod -> case args of
                 (a:as) -> case item of
                     "clone" -> JCast jrty (JInvoke (JExMem a item []) as) -- due to java brain damage
                     _ -> JInvoke (JExMem a item targs) as
                 _ -> JAtom "null"
             -- member access: the first character of the item is dropped to get the member name
             NIMember -> case args of
                 [a] -> (JExMem a (tail item) [])
                 _ -> JAtom "null"
             -- array creation, with a double cast through Something when jrty has a gargs field
             NINewArray
                | jrty.{gargs?} -> JCast jrty (JCast Something (JNewArray jrty (head args)))
                | otherwise = JNewArray jrty (head args)
             NIStatic -> case sigmas of
                 -- with arguments: static method call; an item of the form qualifier.name
                 -- is split so that the generic arguments can be passed
                 (_:_) -> case item =~~ ´^(.+)\.([\w\d\$_]+)$´ of
                        [_, Just qual, Just base] →
                            JInvoke (JX.static base Nativ{typ=qual, gargs=targs, generic = true}) args
                        _ → JInvoke (JAtom item) args
                 -- without arguments: static value; an item ending in .class whose Java type
                 -- is java.lang.Class with a type argument that itself has type arguments
                 -- gets the class literal of the raw type, cast to jrty
                 _  | item ~ ´^.+\.class$´, 
                      Nativ {typ="java.lang.Class", gargs=[x]} ← jrty,
                      x.{gargs?}, not (null x.gargs),
                      not x.{generic?} || x.generic
                        = JCast jrty (JCast Something JStMem{jt=rawType x, name="class", targs=[]}) 
                    | otherwise = JAtom item
             NIArrayGet -> case args of
                [a,b] -> JArrayGet a b
                _     -> JAtom "bad array get"      -- error was flagged before
             -- array update is the assignment a[b] = c
             NIArraySet -> case args of
                [a,b,c] -> JBin (JArrayGet a b) "=" c
                _     -> JAtom "bad array set"      -- error was flagged before 
-- anything that is not a native SymV cannot be called natively
nativeCall g sym subst aexs = error ("nativeCall: no function " 
    ++ show sym.pos.first.line
    ++ ", " ++ nicer sym g)
-- 
{--
    Generates the Java statements that wrap the call of a native method.

    What is generated depends on the return type @rtau@:

    - an ST type ('unST'): a lambda with a @_state@ parameter whose body is the
      wrap code for the inner result type; when the symbol declares thrown
      exceptions the body is a try block followed by catch blocks that rethrow
      the result of @wrapIfNeeded@
    - an exception type ('U.isException'): a try block that delivers the result
      via @right@, followed by catch blocks that deliver the exception via @left@
    - a 'Maybe' type: the call is passed to @_toMaybe@
    - the unit type: the call is evaluated as a statement and @Unit@ is delivered
    - otherwise the strict binding of the call is delivered

    @jreturn@ turns the resulting expression into a statement.
    The remaining arguments are passed on to 'nativeCall'.

    It is an error to apply this to anything but a native 'SymV'.
 -}
wrapCode g jreturn rtau (sym@SymV {nativ = Just item, throwing}) subst aexs
    -- ST action
    | Just (stau, atau) <- unST rtau = let
            sjt     = tauJT g stau          -- type #1 for parameterization of ST s a
            ajt     = tauJT g atau          -- return type of the ST action
            ssig    = ForAll [] (RhoTau [] stau)
            -- inside the lambda, results are returned through 'thunkWhenNeeded'
            mktup x = JReturn (mkpure sjt ajt x)
            -- statements for the result type of the action
            code    = wrapCode g mktup atau sym subst aexs 
            try     = JBlockX "try" code
            -- lambda body: plain code, or try/catch when exceptions are declared
            rbody
                | null throwing   = code
                | otherwise = try : catches
                where
                    -- one catch block per declared exception
                    catches = map mkCatch throwing
                    mkCatch t = JBlockX (catch t) [JThrow wrap]
                    wrap = (JX.invoke [JAtom "ex"] . JX.static "wrapIfNeeded") jtWrapped

            ret     = jreturn fun -- (mkst sjt ajt fun) 
            -- the lambda, cast to its function type
            fun     = JCast{jt = lambda, 
                            jex = JLambda{
                                fargs = [(attrFinal, ssig, lazy sjt, "_state")], 
                                code  = Right (JBlock rbody)}}
            lambda  = lambdaType (st sjt ajt)
        in pure ret
    -- exception type: exs is the exception part, mtau the type of the regular result
    | Just (exs, mtau) <- U.isException g rtau = let
            jexs    = autoboxed (tauJT g exs)
            jmtau   = autoboxed (tauJT g mtau)   
            -- the regular result is wrapped with 'right'
            code    = wrapCode g (jreturn . right jexs jmtau) mtau sym subst aexs
            try     = JBlockX "try" code
            -- When exs is itself an exception type, a catch block for its second
            -- component is made and the first component is processed recursively.
            -- go wraps the caught exception on the way out.
            mkCatch exs go = case U.isException g exs of
                    Just (lty, rty) -> JBlockX (catch rty) r : mkCatch lty (go . left jlty jrty)
                        where
                            jlty = tauJT g lty
                            jrty = tauJT g rty
                            r = [(jreturn . go . right jlty jrty) (JAtom "ex")]
                    Nothing -> [JBlockX (catch exs) [(jreturn . go) $ (JAtom "ex")]]


            -- catch   = JBlockX "catch (Exception ex)" [left]
        -- the catch blocks are emitted in reverse order of their construction
        in try : reverse ( mkCatch exs (left jexs jmtau))
    -- Maybe result
    | Just atau <- U.isMaybe rtau =  let
                        mkmb  =  JInvoke (JX.static "_toMaybe" base) [bind.jex]
         in [jreturn mkmb]
    -- unit result: evaluate the call for its effect, then deliver Unit
    | Just _ <- U.isUnit rtau = let
            unit   = JX.static "Unit" tunit
         in [JEx bind.jex, jreturn unit]
    | otherwise = [jreturn (strictBind g bind).jex]
    where
        -- function type used for the lambda of an ST action
        st a b  = Func [boxed a, boxed b]
        -- stpure a b = (JX.staticMember (memberOf (JName "PreludeBase" "TST") "$return")).{targs=[boxed a, boxed b]}
        -- stst   a b = Ref (memberOf (JName "PreludeBase" "TST") "DST")   [boxed b, boxed a]
        mkpure a b x = {- JInvoke (stpure a b) [x]  -} thunkWhenNeeded b x
        -- mkst   a b x = x -- JInvoke (JX.static "mk" (stst   a b)) [x]
        -- the binding for the direct native call
        -- NOTE(review): subst is passed on here, the remark about substitutions looks stale - confirm
        wbind   =  nativeCall g sym subst aexs     -- no substitutions!
        -- when 'wildReturn' is not empty, the strict call expression is cast to the type of the binding
        bind 
            | not (null (wildReturn g sym))  = bnd.{jex ← JCast bnd.jtype}
            | otherwise = wbind
            where bnd = strictBind g wbind
        base    = nativ "PreludeBase" []
        tunit   = nativ "PreludeBase.TUnit" []
        tright a b = nativ "PreludeBase.TEither.DRight" [a,b]
        tleft  a b = nativ "PreludeBase.TEither.DLeft"  [a,b]
        -- construct right/left values through the static mk method
        right a b x = JInvoke (JX.static "mk" (tright a b)) [thunkWhenNeeded b x]
        left  a b x = JInvoke (JX.static "mk" (tleft  a b)) [thunkWhenNeeded a x]
        -- header text of a catch block for the given exception type, the variable is named ex
        catch rty = case tauJT g rty of
                        Nativ{typ, gargs} -> "catch (" ++ typ ++ " ex)"
                        other -> error ("bad exception type " ++ show other)        
wrapCode g jreturn rtau sym _ _ = error "wrapCode: no SymV"
 
 
{--
    Code for native functions and/or members.

    The result starts with four Java comments (name with strictness signature and
    return kind, the type, the type variables found by 'wildReturn', the native item).
    They are followed by

    - the wrapper method, if the symbol has arguments or bound type variables
      and is 'wrapped' (nothing if it is not wrapped)
    - otherwise a member initialized with the native expression, if the symbol
      is 'wrapped' or its native kind is not 'NIStatic'
    - otherwise nothing

    Symbols that are not native 'SymV's cause an error.
  -}
methCode :: Global -> Symbol -> SymInfo8 -> [JDecl]
methCode g (sym@SymV {nativ = Just item}) si = [
        JComment ((nice sym g) ++ "  " ++ show sym.strsig ++ "  " ++ show sym.rkind),
        JComment (nicer sym.typ g),
        JComment ("the following type variables are probably wildcards: " ++ joined ", " (map _.var wildr)),
        JComment item] ++
                (if arity then defs 
                 else if wrapped g sym || niKind item != NIStatic 
                    then [member]
                    else [])
    where
        rjt         = tauJT g rty
        rArgs       = lambdaArgDef g attrFinal si.argSigs (getArgs g)
        wArgs       = argDefs attrFinal si (getArgs g)
        wildr       = wildReturn g sym
        name        = symJavaName g sym                        -- X.foo
        ftargs      = targs g sym.typ                                -- <a,b,c>
        -- the argument definitions used to build the call expressions
        args        = if haswrapper then wArgs else rArgs
        haswrapper  = arity && wrapped g sym -- (not (null bnds))
        -- methods return the result, a member initializer needs an expression statement (see unex)
        jreturn     = if arity  then JReturn else JEx
        bndWcode  x = newBind g (ForAll [] (RhoTau [] rty))  x
        -- the wrapper is marked unchecked when there are wild return type variables
        -- or when 'unsafeCast' holds for the symbol
        attr
            | not (null wildr) = attrs [JUnchecked, JPublic, JStatic, JFinal]
            | unsafeCast g sym = attrs [JUnchecked, JPublic, JStatic, JFinal]
            | otherwise        = attrTop
 
        -- the statements: wrap code for wrapped symbols, else the plain native call
        wcode       = if wrapped g sym
                        then wrapCode g jreturn rty sym TreeMap.empty (map (_.jex . instArg g) args)
                        else let
                                bind = nativeCall g sym TreeMap.empty (map (_.jex . instArg g) args)
                            in [jreturn bind.jex]
        wrappers    = if haswrapper then [{- inst, -} wrapper] else [{-inst-}]
        wrapper     = JMethod {attr,
                                 gvars = ftargs, jtype = si.returnJT, name = name.base,
                                 args = wArgs, body = JBlock wcode}
        defs        = wrappers
        -- extracts the expression of a single expression statement and makes it strict
        unex  [(JEx x)] = (strictBind g (bndWcode x)).jex
        unex  _ = error "unex: need list with 1 ex"
         
        member = JMember {attr = attrTop,
                          jtype = rjt,
                          name = (symJavaName g sym).base,
                          init = Just (unex wcode)}
 
        (rty, atys) = U.returnType sym.typ.rho
        -- true when there are arguments or bound type variables
        arity       = not (null atys) || not (null sym.typ.bound)
 
methCode g sym _ = Prelude.error ("line " ++ show sym.pos.first.line 
                    ++ ": can not compile " ++ nice sym g)

{--
    Find the type variables that occur only in the return type of a symbol,
    that is, neither in the argument types nor in the generic arguments
    of the symbol. The phantom type variable of an ST return type does not count.

    A non empty result probably indicates something like:

    >  Class<?> loadClass(String name)

    where the result of the call needs a cast.

    For symbols that are not 'SymV's the result is empty.
-}
wildReturn ∷ Global → Symbol → [Tau]
wildReturn g (symv@SymV{}) =
        [ tv | tv@TVar{} ← values (U.freeTauTVars [] TreeMap.empty resultTau),
               not (isStPhantom tv.var
                        || elemBy (using _.var) tv argVars
                        || elemBy (using _.var) tv genericVars) ]
    where
        (resultTau, argSigmas) = U.returnType symv.typ.rho
        -- type variables mentioned in the generic arguments of the symbol
        genericVars = concatMap (values . U.freeTauTVars [] TreeMap.empty) symv.gargs
        -- type variables mentioned in the argument types
        argVars     = concatMap (values . U.freeRhoTVars [] TreeMap.empty . _.rho) argSigmas
        -- recognizes the name of the ST phantom type variable, if there is one
        isStPhantom = case unST resultTau of
            Just (st@TVar{}, _) → (st.var ==)
            _                   → const false
wildReturn _ _ = []


{--
    Makes a public Java method named @nativnm@ that forwards to the static
    member @fldnm@ of @irjt@ by means of 'invokeIR', with @this@ as first argument.

    The formal arguments and the return type come from @nativsi@, without its first argument.
    The strictness of the arguments is taken from the tail of the strictness
    signature of @fldsym@. Unless all of them are strict, @call@ is invoked
    on the result before it is returned.
 -}
wrapIRMethod :: Global -> JExpr -> JType -> SymInfo8 -> String -> String -> Symbol -> JDecl
wrapIRMethod g this irjt nativsi nativnm fldnm fldsym = JMethod
    { attr  = attrs [JPublic]
    , gvars = []
    , jtype = nativsi.returnJT
    , name  = nativnm
    , args  = formals
    , body  = JBlock [JReturn (finish invocation)]
    }
  where
    -- formal arguments, the first one is dropped
    formals = argDefs attrFinal (nativsi.{ argSigs <- tail, argJTs <- tail }) (getArgs g)
    argStrictness = case fldsym.strsig of
        Strictness.S xs -> tail xs
        _               -> []
    invocation = invokeIR irjt fldnm this
                    (zipWith (\(_, _, _, x) s -> StriArg x s.isStrict) formals argStrictness)
    -- how to detect strictness of result value?
    viaCall jex = JInvoke { args = [], jex = JExMem { jex, name = "call", targs = [] } }
    finish = if all (_.isStrict) argStrictness then id else viaCall

--- The name of a Java argument together with a flag that tells whether it is strict.
data StriArg = StriArg
    { name   :: String
    , strict :: Bool
    }

{--
    The Java expression for an argument: a strict argument is referenced by
    its name, any other one is passed to the static @lazy@ member of @Thunk@.
 -}
striArgExpr :: StriArg -> JExpr
striArgExpr (StriArg{name, strict})
    | strict    = JAtom name
    | otherwise = JInvoke { args = [JAtom name], jex = thunkLazy }
    where
        -- TODO use the proper JType
        thunkType = Ref { jname = JName { qual = "", base = "Thunk" }, gargs = [] }
        thunkLazy = JStMem { jt = thunkType, name = "lazy", targs = [] }

{--
    Invokes the static member @name@ of type @jt@, passing @this@ first
    and then the expressions for the given arguments (see 'striArgExpr').
 -}
invokeIR :: JType -> String -> JExpr -> [StriArg] -> JExpr
invokeIR jt name this args = JInvoke { args = actuals, jex = target }
    where
        target  = JStMem { jt, name, targs = [] }
        actuals = this : [ striArgExpr arg | arg <- args ]
